import torch

# Parent directory where the train/val/test data are stored.
DATA_ROOT = '../../data/DailyDialog/'
# Directory used to buffer checkpoints.
MODEL_ROOT = '../../model/dialog/'

# Special tokens, in this fixed order: unknown, padding, begin, end.
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

MAX_LENGTH = 100

# Initial learning rate; exposed under both the upper- and lower-case name.
LR = 0.1
lr = LR

# Use the first CUDA device when one is available, otherwise the CPU.
# Both names refer to the same torch.device object.
DEVICE = (
    torch.device("cuda:0")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
device = DEVICE

# num_steps is the sentence length fed to the model: the actual sentence
# length + 1 (for the eos on the encoder side and the bos on the decoder
# side -- the original note reads "inc", presumably meaning the encoder).

# Model dim 512, FFN dim 1024, 6 layers, 4 attention heads.
num_hiddens = 512
num_layers = 6
dropout = 0.1
batch_size = 64
num_steps = 25

ffn_num_input = 512
ffn_num_hiddens = 1024
num_heads = 4

key_size = query_size = value_size = 128

norm_shape = [512]  # i.e. the hidden dimension
num_epochs = 20
# NOTE(review): left as an empty string; how consumers interpret this value
# is not visible from this file -- confirm before changing.
num_examples = ""


# Alternative, much smaller configuration (disabled). Uncomment to use it
# in place of the values above.
# num_hiddens, num_layers, dropout, batch_size, num_steps = 16, 2, 0.1, 64, 25

# ffn_num_input, ffn_num_hiddens, num_heads = 16, 16, 2
# key_size, query_size, value_size = 8, 8, 8
# norm_shape = [16] # i.e. the hidden dimension
# num_epochs = 20
# num_examples = ""